home *** CD-ROM | disk | FTP | other *** search
/ Total Network Tools 2002 / NextStepPublishing-TotalNetworkTools2002-Win95.iso / Archive / Misc Servers / Zope.exe / RESOLVER.PY < prev    next >
Encoding:
Python Source  |  2000-06-02  |  13.1 KB  |  443 lines

  1. # -*- Mode: Python; tab-width: 4 -*-
  2.  
  3. #
  4. #    Author: Sam Rushing <rushing@nightmare.com>
  5. #
  6.  
  7. RCS_ID =  '$Id: resolver.py,v 1.6 2000/06/02 14:22:48 brian Exp $'
  8.  
  9.  
  10. # Fast, low-overhead asynchronous name resolver.  uses 'pre-cooked'
  11. # DNS requests, unpacks only as much as it needs of the reply.
  12.  
  13. # see rfc1035 for details
  14.  
  15. import string
  16. import asyncore
  17. import socket
  18. import sys
  19. import time
  20. from counter import counter
  21.  
  22. VERSION = string.split(RCS_ID)[2]
  23.  
  24. # header
  25. #                                    1  1  1  1  1  1
  26. #      0  1  2  3  4  5  6  7  8  9  0  1  2  3  4  5
  27. #    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
  28. #    |                      ID                       |
  29. #    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
  30. #    |QR|   Opcode  |AA|TC|RD|RA|   Z    |   RCODE   |
  31. #    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
  32. #    |                    QDCOUNT                    |
  33. #    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
  34. #    |                    ANCOUNT                    |
  35. #    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
  36. #    |                    NSCOUNT                    |
  37. #    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
  38. #    |                    ARCOUNT                    |
  39. #    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
  40.  
  41.  
  42. # question
  43. #                                    1  1  1  1  1  1
  44. #      0  1  2  3  4  5  6  7  8  9  0  1  2  3  4  5
  45. #    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
  46. #    |                                               |
  47. #    /                     QNAME                     /
  48. #    /                                               /
  49. #    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
  50. #    |                     QTYPE                     |
  51. #    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
  52. #    |                     QCLASS                    |
  53. #    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
  54.  
  55. # build a DNS address request, _quickly_
  56. def fast_address_request (host, id=0):
  57.     return (
  58.         '%c%c' % (chr((id>>8)&0xff),chr(id&0xff))
  59.         + '\001\000\000\001\000\000\000\000\000\000%s\000\000\001\000\001' % (
  60.             string.join (
  61.                 map (
  62.                     lambda part: '%c%s' % (chr(len(part)),part),
  63.                     string.split (host, '.')
  64.                     ), ''
  65.                 )
  66.             )
  67.         )
  68.  
  69. def fast_ptr_request (host, id=0):
  70.     return (
  71.         '%c%c' % (chr((id>>8)&0xff),chr(id&0xff))
  72.         + '\001\000\000\001\000\000\000\000\000\000%s\000\000\014\000\001' % (
  73.             string.join (
  74.                 map (
  75.                     lambda part: '%c%s' % (chr(len(part)),part),
  76.                     string.split (host, '.')
  77.                     ), ''
  78.                 )
  79.             )
  80.         )
  81.  
  82. def unpack_name (r,pos):
  83.     n = []
  84.     while 1:
  85.         ll = ord(r[pos])
  86.         if (ll&0xc0):
  87.             # compression
  88.             pos = (ll&0x3f << 8) + (ord(r[pos+1]))
  89.         elif ll == 0:
  90.             break            
  91.         else:
  92.             pos = pos + 1
  93.             n.append (r[pos:pos+ll])
  94.             pos = pos + ll
  95.     return string.join (n,'.')
  96.  
  97. def skip_name (r,pos):
  98.     s = pos
  99.     while 1:
  100.         ll = ord(r[pos])
  101.         if (ll&0xc0):
  102.             # compression
  103.             return pos + 2
  104.         elif ll == 0:
  105.             pos = pos + 1
  106.             break
  107.         else:
  108.             pos = pos + ll + 1
  109.     return pos
  110.  
  111. def unpack_ttl (r,pos):
  112.     return reduce (
  113.         lambda x,y: (x<<8)|y,
  114.         map (ord, r[pos:pos+4])
  115.         )
  116.  
  117. # resource record
  118. #                                    1  1  1  1  1  1
  119. #      0  1  2  3  4  5  6  7  8  9  0  1  2  3  4  5
  120. #    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
  121. #    |                                               |
  122. #    /                                               /
  123. #    /                      NAME                     /
  124. #    |                                               |
  125. #    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
  126. #    |                      TYPE                     |
  127. #    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
  128. #    |                     CLASS                     |
  129. #    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
  130. #    |                      TTL                      |
  131. #    |                                               |
  132. #    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
  133. #    |                   RDLENGTH                    |
  134. #    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--|
  135. #    /                     RDATA                     /
  136. #    /                                               /
  137. #    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
  138.  
  139. def unpack_address_reply (r):
  140.     ancount = (ord(r[6])<<8) + (ord(r[7]))
  141.     # skip question, first name starts at 12,
  142.     # this is followed by QTYPE and QCLASS
  143.     pos = skip_name (r, 12) + 4
  144.     if ancount:
  145.         # we are looking very specifically for
  146.         # an answer with TYPE=A, CLASS=IN (\000\001\000\001)
  147.         for an in range(ancount):
  148.             pos = skip_name (r, pos)
  149.             if r[pos:pos+4] == '\000\001\000\001':
  150.                 return (
  151.                     unpack_ttl (r,pos+4),
  152.                     '%d.%d.%d.%d' % tuple(map(ord,r[pos+10:pos+14]))
  153.                     )
  154.             # skip over TYPE, CLASS, TTL, RDLENGTH, RDATA
  155.             pos = pos + 8
  156.             rdlength = (ord(r[pos])<<8) + (ord(r[pos+1]))
  157.             pos = pos + 2 + rdlength
  158.         return 0, None
  159.     else:
  160.         return 0, None
  161.  
  162. def unpack_ptr_reply (r):
  163.     ancount = (ord(r[6])<<8) + (ord(r[7]))
  164.     # skip question, first name starts at 12,
  165.     # this is followed by QTYPE and QCLASS
  166.     pos = skip_name (r, 12) + 4
  167.     if ancount:
  168.         # we are looking very specifically for
  169.         # an answer with TYPE=PTR, CLASS=IN (\000\014\000\001)
  170.         for an in range(ancount):
  171.             pos = skip_name (r, pos)
  172.             if r[pos:pos+4] == '\000\014\000\001':
  173.                 return (
  174.                     unpack_ttl (r,pos+4),
  175.                     unpack_name (r, pos+10)
  176.                     )
  177.             # skip over TYPE, CLASS, TTL, RDLENGTH, RDATA
  178.             pos = pos + 8
  179.             rdlength = (ord(r[pos])<<8) + (ord(r[pos+1]))
  180.             pos = pos + 2 + rdlength
  181.         return 0, None
  182.     else:
  183.         return 0, None
  184.  
  185.  
  186. # This is a UDP (datagram) resolver.
  187.  
  188. #
  189. # It may be useful to implement a TCP resolver.  This would presumably
  190. # give us more reliable behavior when things get too busy.  A TCP
  191. # client would have to manage the connection carefully, since the
  192. # server is allowed to close it at will (the RFC recommends closing
  193. # after 2 minutes of idle time).
  194. #
  195. # Note also that the TCP client will have to prepend each request
  196. # with a 2-byte length indicator (see rfc1035).
  197. #
  198.  
  199. class resolver (asyncore.dispatcher):
  200.     id = counter()
  201.     def __init__ (self, server='127.0.0.1'):
  202.         asyncore.dispatcher.__init__ (self)
  203.         self.create_socket (socket.AF_INET, socket.SOCK_DGRAM)
  204.         self.server = server
  205.         self.request_map = {}
  206.         self.last_reap_time = int(time.time())      # reap every few minutes
  207.  
  208.     def writable (self):
  209.         return 0
  210.  
  211.     def log (self, *args):
  212.         pass
  213.  
  214.     def handle_close (self):
  215.         self.log_info('closing!')
  216.         self.close()
  217.  
  218.     def handle_error (self):      # don't close the connection on error
  219.         (file,fun,line), t, v, tbinfo = asyncore.compact_traceback()
  220.         self.log_info(
  221.                 'Problem with DNS lookup (%s:%s %s)' % (t, v, tbinfo),
  222.                 'error')
  223.  
  224.     def get_id (self):
  225.         return (self.id.as_long() % (1<<16))
  226.  
  227.     def reap (self):          # find DNS requests that have timed out
  228.         now = int(time.time())
  229.         if now - self.last_reap_time > 180:        # reap every 3 minutes
  230.             self.last_reap_time = now              # update before we forget
  231.             for k,(host,unpack,callback,when) in self.request_map.items():
  232.                 if now - when > 180:               # over 3 minutes old
  233.                     del self.request_map[k]
  234.                     try:                           # same code as in handle_read
  235.                         callback (host, 0, None)   # timeout val is (0,None) 
  236.                     except:
  237.                         (file,fun,line), t, v, tbinfo = asyncore.compact_traceback()
  238.                         self.log_info('%s %s %s' % (t,v,tbinfo), 'error')
  239.  
  240.     def resolve (self, host, callback):
  241.         self.reap()                                # first, get rid of old guys
  242.         self.socket.sendto (
  243.             fast_address_request (host, self.get_id()),
  244.             (self.server, 53)
  245.             )
  246.         self.request_map [self.get_id()] = (
  247.             host, unpack_address_reply, callback, int(time.time()))
  248.         self.id.increment()
  249.  
  250.     def resolve_ptr (self, host, callback):
  251.         self.reap()                                # first, get rid of old guys
  252.         ip = string.split (host, '.')
  253.         ip.reverse()
  254.         ip = string.join (ip, '.') + '.in-addr.arpa'
  255.         self.socket.sendto (
  256.             fast_ptr_request (ip, self.get_id()),
  257.             (self.server, 53)
  258.             )
  259.         self.request_map [self.get_id()] = (
  260.             host, unpack_ptr_reply, callback, int(time.time()))
  261.         self.id.increment()
  262.  
  263.     def handle_read (self):
  264.         reply, whence = self.socket.recvfrom (512)
  265.         # for security reasons we may want to double-check
  266.         # that <whence> is the server we sent the request to.
  267.         id = (ord(reply[0])<<8) + ord(reply[1])
  268.         if self.request_map.has_key (id):
  269.             host, unpack, callback, when = self.request_map[id]
  270.             del self.request_map[id]
  271.             ttl, answer = unpack (reply)
  272.             try:
  273.                 callback (host, ttl, answer)
  274.             except:
  275.                 (file,fun,line), t, v, tbinfo = asyncore.compact_traceback()
  276.                 self.log_info('%s %s %s' % ( t,v,tbinfo), 'error')
  277.  
  278. class rbl (resolver):
  279.  
  280.     def resolve_maps (self, host, callback):
  281.         ip = string.split (host, '.')
  282.         ip.reverse()
  283.         ip = string.join (ip, '.') + '.rbl.maps.vix.com'
  284.         self.socket.sendto (
  285.             fast_ptr_request (ip, self.get_id()),
  286.             (self.server, 53)
  287.             )
  288.         self.request_map [self.get_id()] = host, self.check_reply, callback
  289.         self.id.increment()
  290.     
  291.     def check_reply (self, r):
  292.         # we only need to check RCODE.
  293.         rcode = (ord(r[3])&0xf)
  294.         self.log_info('MAPS RBL; RCODE =%02x\n %s' % (rcode, repr(r)))
  295.         return 0, rcode # (ttl, answer)
  296.  
  297.  
  298. class hooked_callback:
  299.     def __init__ (self, hook, callback):
  300.         self.hook, self.callback = hook, callback
  301.  
  302.     def __call__ (self, *args):
  303.         apply (self.hook, args)
  304.         apply (self.callback, args)
  305.  
  306. class caching_resolver (resolver):
  307.     "Cache DNS queries.  Will need to honor the TTL value in the replies"
  308.  
  309.     def __init__ (*args):
  310.         apply (resolver.__init__, args)
  311.         self = args[0]
  312.         self.cache = {}
  313.         self.forward_requests = counter()
  314.         self.reverse_requests = counter()
  315.         self.cache_hits = counter()
  316.  
  317.     def resolve (self, host, callback):
  318.         self.forward_requests.increment()
  319.         if self.cache.has_key (host):
  320.             when, ttl, answer = self.cache[host]
  321.             # ignore TTL for now
  322.             callback (host, ttl, answer)
  323.             self.cache_hits.increment()
  324.         else:
  325.             resolver.resolve (
  326.                 self,
  327.                 host,
  328.                 hooked_callback (
  329.                     self.callback_hook,
  330.                     callback
  331.                     )
  332.                 )
  333.             
  334.     def resolve_ptr (self, host, callback):
  335.         self.reverse_requests.increment()
  336.         if self.cache.has_key (host):
  337.             when, ttl, answer = self.cache[host]
  338.             # ignore TTL for now
  339.             callback (host, ttl, answer)
  340.             self.cache_hits.increment()
  341.         else:
  342.             resolver.resolve_ptr (
  343.                 self,
  344.                 host,
  345.                 hooked_callback (
  346.                     self.callback_hook,
  347.                     callback
  348.                     )
  349.                 )
  350.  
  351.     def callback_hook (self, host, ttl, answer):
  352.         self.cache[host] = time.time(), ttl, answer
  353.  
  354.     SERVER_IDENT = 'Caching DNS Resolver (V%s)' % VERSION
  355.  
  356.     def status (self):
  357.         import status_handler
  358.         import producers
  359.         return producers.simple_producer (
  360.             '<h2>%s</h2>'                    % self.SERVER_IDENT
  361.             + '<br>Server: %s'                % self.server
  362.             + '<br>Cache Entries: %d'        % len(self.cache)
  363.             + '<br>Outstanding Requests: %d' % len(self.request_map)
  364.             + '<br>Forward Requests: %s'    % self.forward_requests
  365.             + '<br>Reverse Requests: %s'    % self.reverse_requests
  366.             + '<br>Cache Hits: %s'            % self.cache_hits
  367.             )
  368.  
  369. #test_reply = """\000\000\205\200\000\001\000\001\000\002\000\002\006squirl\011nightmare\003com\000\000\001\000\001\300\014\000\001\000\001\000\001Q\200\000\004\315\240\260\005\011nightmare\003com\000\000\002\000\001\000\001Q\200\000\002\300\014\3006\000\002\000\001\000\001Q\200\000\015\003ns1\003iag\003net\000\300\014\000\001\000\001\000\001Q\200\000\004\315\240\260\005\300]\000\001\000\001\000\000\350\227\000\004\314\033\322\005"""
  370. # def test_unpacker ():
  371. #     print unpack_address_reply (test_reply)
  372. # import time
  373. # class timer:
  374. #     def __init__ (self):
  375. #         self.start = time.time()
  376. #     def end (self):
  377. #         return time.time() - self.start
  378. # # I get ~290 unpacks per second for the typical case, compared to ~48
  379. # # using dnslib directly.  also, that latter number does not include
  380. # # picking the actual data out.
  381. # def benchmark_unpacker():
  382. #     r = range(1000)
  383. #     t = timer()
  384. #     for i in r:
  385. #         unpack_address_reply (test_reply)
  386. #     print '%.2f unpacks per second' % (1000.0 / t.end())
  387.  
  388. if __name__ == '__main__':
  389.     import sys
  390.     if len(sys.argv) == 1:
  391.         print 'usage: %s [-r] [-s <server_IP>] host [host ...]' % sys.argv[0]
  392.         sys.exit(0)
  393.     elif ('-s' in sys.argv):
  394.         i = sys.argv.index('-s')
  395.         server = sys.argv[i+1]
  396.         del sys.argv[i:i+2]
  397.     else:
  398.         server = '127.0.0.1'
  399.  
  400.     if ('-r' in sys.argv):
  401.         reverse = 1
  402.         i = sys.argv.index('-r')
  403.         del sys.argv[i]
  404.     else:
  405.         reverse = 0
  406.  
  407.     if ('-m' in sys.argv):
  408.         maps = 1
  409.         sys.argv.remove ('-m')
  410.     else:
  411.         maps = 0
  412.  
  413.     if maps:
  414.         r = rbl (server)
  415.     else:
  416.         r = caching_resolver(server)
  417.  
  418.     count = len(sys.argv) - 1
  419.  
  420.     def print_it (host, ttl, answer):
  421.         global count
  422.         print '%s: %s' % (host, answer)
  423.         count = count - 1
  424.         if not count:
  425.             r.close()
  426.  
  427.     for host in sys.argv[1:]:
  428.         if reverse:
  429.             r.resolve_ptr (host, print_it)
  430.         elif maps:
  431.             r.resolve_maps (host, print_it)
  432.         else:
  433.             r.resolve (host, print_it)
  434.  
  435.     # hooked asyncore.loop()
  436.     while asyncore.socket_map:
  437.         asyncore.poll (30.0)
  438.         print 'requests outstanding: %d' % len(r.request_map)
  439.